
Conversation

@albertbou92 (Contributor) commented Nov 15, 2022

Description

Added an A2C objective class.

I also created the helper functions necessary to run an A2C example, including make_a2c_loss, A2CLossConfig, make_a2c_model and A2CModelConfig.

Creating a make_a2c_model helper function was not strictly necessary, since the models are the same as in PPO. However, I wanted to use fewer nodes in the hidden layers, so I created make_a2c_model instead of modifying make_ppo_model. The two helpers can probably be merged in the future if necessary, with the network architecture passed as a parameter (see the sketch below).
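As a purely illustrative sketch of that idea (not the actual make_a2c_model / make_ppo_model signatures; the names obs_dim, action_dim and hidden_sizes are hypothetical), the architecture could be exposed as a parameter along these lines:

```python
from torch import nn

# Hypothetical sketch only: a single builder where the hidden-layer widths are a
# parameter, so A2C and PPO could share it instead of duplicating helpers.
# Names and signature are illustrative and do not match the PR's helpers.
def make_actor_critic(obs_dim: int, action_dim: int, hidden_sizes=(64, 64)):
    def mlp(out_dim: int) -> nn.Sequential:
        layers, in_dim = [], obs_dim
        for width in hidden_sizes:
            layers += [nn.Linear(in_dim, width), nn.Tanh()]
            in_dim = width
        layers.append(nn.Linear(in_dim, out_dim))
        return nn.Sequential(*layers)

    actor = mlp(2 * action_dim)  # e.g. mean and log-std of a Gaussian policy
    critic = mlp(1)              # state-value head
    return actor, critic
```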

Finally, I tuned the parameters in the config.yaml file until I found a configuration that learns reasonably well in the HalfCheetah-v4 environment.

Motivation and Context

There is an open issue about A2C, and while it is similar to REINFORCE and PPO, which are already in the repo, the objective is not the same. In particular, it includes an entropy term (not present in REINFORCE) and it lacks the log-prob ratio weighting, the clipping and the KL term present in PPO.
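For reference, here is a minimal standalone sketch of that objective in plain PyTorch. It is an illustration of the terms described above, not the A2CLoss class added in this PR; the tensor names and default coefficients are assumptions.

```python
import torch.nn.functional as F

def a2c_loss(dist, action, advantage, value, value_target,
             entropy_coef=0.01, critic_coef=0.5):
    # Plain policy-gradient term weighted by the advantage: unlike PPO there is
    # no importance ratio, no clipping and no KL penalty.
    log_prob = dist.log_prob(action)
    loss_actor = -(log_prob * advantage.detach()).mean()
    # Entropy bonus (absent from REINFORCE) to keep the policy exploratory.
    loss_entropy = -entropy_coef * dist.entropy().mean()
    # Value head regressed towards its target (e.g. a GAE or n-step return).
    loss_critic = critic_coef * F.mse_loss(value, value_target)
    return loss_actor + loss_entropy + loss_critic
```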

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 15, 2022
@albertbou92 albertbou92 changed the title A2C objective class and train example [Feature] A2C objective class and train example Nov 15, 2022
@vmoens (Collaborator) left a comment

Would you mind merging main and trying to solve the issues with the new "next" logic? Let me also know what you think of it :)

vmoens and others added 9 commits November 16, 2022 20:13
* init

* strict=False

* amend

* amend
* Add auto-compute stats feature for ObservationNorm

* Fix issue in ObservNorm init function

* Quick refactor of ObservationNorm init method

* Minor refactoring and adding more tests for ObservationNorm

* lint

* docstring

* docstring

Co-authored-by: vmoens <vincentmoens@gmail.com>
* init

* [Feature] Nested composite spec (pytorch#654)

* [Feature] Move `transform.forward` to `transform.step` (pytorch#660)

* transform step function

* amend

* amend

* amend

* amend

* amend

* fixing key names

* fixing key names

* [Refactor] Transform next remove (pytorch#661)

* Refactor "next_" into ("next", ) (pytorch#673)

* amend

* amend

* bugfix

* init

* strict=False

* strict=False

* minor

* amend

* [BugFix] Use GitHub for flake8 pre-commit hook (pytorch#679)

* amend

* [BugFix] Update to strict select (pytorch#675)

* init

* strict=False

* amend

* amend

* [Feature] Auto-compute stats for ObservationNorm (pytorch#669)

* Add auto-compute stats feature for ObservationNorm

* Fix issue in ObservNorm init function

* Quick refactor of ObservationNorm init method

* Minor refactoring and adding more tests for ObservationNorm

* lint

* docstring

* docstring

Co-authored-by: vmoens <vincentmoens@gmail.com>

* amend

* amend

* lint

* bf

* bf

* amend

Co-authored-by: Romain Julien <romainjulien@fb.com>

Co-authored-by: Romain Julien <romainjulien@fb.com>
@albertbou92 (Contributor, Author) commented Nov 17, 2022

Done! I brought in all the changes from main, and the training script now computes the initial stats with the key "observation_vector" instead of "next_observation_vector". It should be equivalent, since it is actually the same tensor delayed by one timestep. I also checked that the example script runs without issues.
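A toy check of that claim (the shapes and rollout length below are made up for illustration): the two keys hold the same values shifted by one step, so normalization stats computed on either key only differ through the boundary frames.

```python
import torch

T, obs_dim = 1000, 17                  # hypothetical rollout length / obs size
obs = torch.randn(T + 1, obs_dim)      # o_0 ... o_T
observation_vector = obs[:-1]          # o_0 ... o_{T-1}
next_observation_vector = obs[1:]      # o_1 ... o_T  (same tensor, shifted by 1)

# For long rollouts the resulting means/stds are essentially identical.
print((observation_vector.mean(0) - next_observation_vector.mean(0)).abs().max())
```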

@vmoens (Collaborator) commented Nov 18, 2022

It feels like you merged each diff independently, which makes a gigantic diff here (over 50 files changed).
Can you run git merge main and see what happens?

vmoens and others added 5 commits November 18, 2022 12:52
* amend

* amend

* amend

* amend

* amend

* amend
@vmoens vmoens added the new algo New algorithm request or PR label Nov 18, 2022
@vmoens (Collaborator) left a comment

LGTM overall. The lint test is failing; running pre-commit should solve that. We should consider adding this to the example test pipeline (#687). After that I think we'll be good to go!

@codecov (bot) commented Nov 21, 2022

Codecov Report

Merging #680 (60c2730) into main (170c6f3) will increase coverage by 0.09%.
The diff coverage is 93.00%.

@@            Coverage Diff             @@
##             main     #680      +/-   ##
==========================================
+ Coverage   87.78%   87.88%   +0.09%     
==========================================
  Files         119      120       +1     
  Lines       20201    20590     +389     
==========================================
+ Hits        17733    18095     +362     
- Misses       2468     2495      +27     
Flag Coverage Δ
habitat-gpu 24.20% <22.53%> (?)
linux-cpu 85.58% <88.00%> (+0.04%) ⬆️
linux-gpu 86.28% <88.00%> (+0.03%) ⬆️
linux-outdeps-gpu 72.46% <55.75%> (-0.36%) ⬇️
linux-stable-cpu 85.44% <88.00%> (+0.04%) ⬆️
linux-stable-gpu 86.14% <88.00%> (+0.03%) ⬆️
macos-cpu 85.31% <88.00%> (+0.05%) ⬆️
olddeps-gpu 75.33% <88.00%> (+0.24%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
torchrl/trainers/helpers/__init__.py 100.00% <ø> (ø)
torchrl/trainers/helpers/models.py 91.30% <82.10%> (-1.98%) ⬇️
test/test_helpers.py 91.59% <91.78%> (+0.03%) ⬆️
test/test_cost.py 96.95% <96.47%> (-0.06%) ⬇️
torchrl/objectives/__init__.py 100.00% <100.00%> (ø)
torchrl/objectives/a2c.py 100.00% <100.00%> (ø)
torchrl/trainers/helpers/losses.py 39.83% <100.00%> (+9.43%) ⬆️
torchrl/modules/models/exploration.py 91.79% <0.00%> (+0.51%) ⬆️


@vmoens (Collaborator) left a comment

Just a couple of minor changes and we're good to go!

@vmoens vmoens merged commit 9a81a97 into pytorch:main Nov 23, 2022
@albertbou92 albertbou92 deleted the a2c branch November 30, 2022 14:35